import pickle

import torch

from ai.run.runner import GxlBaseRunner
import config
import model_store
import data_handler

if __name__ == '__main__':
    # Entry point: preprocess the dataset, build the GateConv model over the
    # saved token vocabulary, and launch training.
    import os  # local import: only needed when run as a script

    # Writes preprocessed artifacts (presumably token_list.pkl and the
    # *.idx.jsonl files read below) into config.PREHAND_SAVE_DIR — TODO confirm.
    data_handler.prehand_data()

    # os.path.join avoids a silently wrong path when PREHAND_SAVE_DIR
    # lacks a trailing separator.
    label_file_path = os.path.join(config.PREHAND_SAVE_DIR, 'token_list.pkl')
    # NOTE: pickle can execute arbitrary code on load. This file is assumed to
    # be produced locally by prehand_data() above; never point it at
    # untrusted input.
    with open(label_file_path, 'rb') as label_file:
        label_list = pickle.load(label_file)

    model = model_store.GateConv(vocabulary=label_list)
    model.to_train()

    # NOTE(review): these index files are built but not passed to fit().
    # Confirm whether fit() locates them itself, or whether they should be
    # forwarded as arguments.
    train_idx_file = os.path.join(config.PREHAND_SAVE_DIR, 'prehand_data_train.idx.jsonl')
    dev_idx_file = os.path.join(config.PREHAND_SAVE_DIR, 'prehand_data_dev.idx.jsonl')

    # Hard-coded to the first CUDA GPU; fails on machines without CUDA.
    model.fit(device=torch.device('cuda:0'), lr_rate=0.121, batch_size=120, epochs=10000)
